{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Parameters to reproduce the paper's results**:\n",
    "* change the optimizer from SGD to Adam in `lib/models.py`,\n",
    "* change the size of the vocabulary from 1000 to 10000 in `train.keep_top_words()` below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys, os\n",
    "sys.path.insert(0, '..')\n",
    "from lib import models, graph, coarsening, utils\n",
    "\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.sparse\n",
    "import numpy as np\n",
    "import time\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Configuration, read through FLAGS by the cells below.\n",
    "# NOTE(review): re-running this cell presumably fails because a flag cannot be\n",
    "# defined twice -- restart the kernel instead of re-executing it.\n",
    "flags = tf.app.flags\n",
    "FLAGS = flags.FLAGS\n",
    "\n",
    "# Graphs.\n",
    "flags.DEFINE_integer(\n",
    "    'number_edges', 16,\n",
    "    'Graph: minimum number of edges per vertex.')\n",
    "flags.DEFINE_string(\n",
    "    'metric', 'cosine',\n",
    "    'Graph: similarity measure (between features).')\n",
    "# TODO: change cgcnn for combinatorial Laplacians.\n",
    "flags.DEFINE_bool(\n",
    "    'normalized_laplacian', True,\n",
    "    'Graph Laplacian: normalized.')\n",
    "flags.DEFINE_integer(\n",
    "    'coarsening_levels', 0,\n",
    "    'Number of coarsened graphs.')\n",
    "\n",
    "# Data.\n",
    "flags.DEFINE_string(\n",
    "    'dir_data', os.path.join('..', 'data', '20news'),\n",
    "    'Directory to store data.')\n",
    "flags.DEFINE_integer(\n",
    "    'val_size', 400,\n",
    "    'Size of the validation set.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Fetch the training set; scikit-learn already performs some cleaning.\n",
    "# `remove` may be (), ('headers',) or ('headers', 'footers', 'quotes').\n",
    "remove = ('headers', 'footers', 'quotes')\n",
    "train = utils.Text20News(data_home=FLAGS.dir_data,\n",
    "                         subset='train', remove=remove)\n",
    "\n",
    "# Pre-processing: transform everything to a-z and whitespace.\n",
    "# Document 1 is printed (first 400 characters) before cleaning and after vectorizing.\n",
    "print(train.show_document(1)[:400])\n",
    "train.clean_text(num='substitute')\n",
    "\n",
    "# Analyzing / tokenizing: transform documents to bags-of-words.\n",
    "# Alternative stop-word lists: sklearn.feature_extraction.text.ENGLISH_STOP_WORDS,\n",
    "# or NLTK's (possibly extended with e.g. 'don', 've').\n",
    "train.vectorize(stop_words='english')\n",
    "print(train.show_document(1)[:400])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Remove short documents.\n",
    "train.data_info(True)\n",
    "wc = train.remove_short_documents(nwords=20, vocab='full')\n",
    "train.data_info()\n",
    "print('shortest: {}, longest: {} words'.format(wc.min(), wc.max()))\n",
    "plt.figure(figsize=(17,5))\n",
    "plt.semilogy(wc, '.');\n",
    "\n",
    "# Remove encoded images.\n",
    "def remove_encoded_images(dataset, freq=1e3, word='ax'):\n",
    "    \"\"\"Drop the documents in which `word` occurs at least `freq` times.\n",
    "\n",
    "    Presumably targets uuencoded binary attachments (e.g. images), whose\n",
    "    tokenization yields many occurrences of 'ax' -- TODO confirm.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataset : object exposing `vocab` (list of words), `data` (assumed to be a\n",
    "        sparse documents x words count matrix) and `keep_documents(idx)`.\n",
    "        It is filtered in place.\n",
    "    freq : documents with fewer than `freq` occurrences of `word` are kept.\n",
    "    word : vocabulary entry to count.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    wc : ndarray, occurrences of `word` in each document, before filtering.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        widx = dataset.vocab.index(word)\n",
    "    except ValueError:\n",
    "        # Word absent from the vocabulary: no document can be affected.\n",
    "        return np.zeros(dataset.data.shape[0])\n",
    "    wc = dataset.data[:, widx].toarray().squeeze()\n",
    "    # flatnonzero always yields a 1-D index array, even when a single\n",
    "    # document is kept (argwhere(...).squeeze() would give a 0-d array).\n",
    "    idx = np.flatnonzero(wc < freq)\n",
    "    dataset.keep_documents(idx)\n",
    "    return wc\n",
    "wc = remove_encoded_images(train)\n",
    "train.data_info()\n",
    "plt.figure(figsize=(17,5))\n",
    "plt.semilogy(wc, '.');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Word embedding: `train.embed()` without argument (default), or pre-trained\n",
    "# word2vec vectors loaded from the GoogleNews binary (set use_word2vec = True).\n",
    "use_word2vec = False\n",
    "if use_word2vec:\n",
    "    train.embed(os.path.join('..', 'data', 'word2vec', 'GoogleNews-vectors-negative300.bin'))\n",
    "else:\n",
    "    train.embed()\n",
    "train.data_info()\n",
    "# Further feature selection. (TODO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Feature selection: keep the top 1000 words (the paper's results use 10000).\n",
    "# Other options include: mutual information or document count.\n",
    "freq = train.keep_top_words(1000, 20)\n",
    "train.data_info()\n",
    "train.show_document(1)\n",
    "\n",
    "# Frequencies returned by keep_top_words, on a log scale.\n",
    "plt.figure(figsize=(17,5))\n",
    "plt.semilogy(freq);\n",
    "\n",
    "# Remove documents whose signal would be the zero vector, i.e. documents\n",
    "# left (almost) empty once restricted to the selected vocabulary.\n",
    "wc = train.remove_short_documents(nwords=5, vocab='selected')\n",
    "train.data_info(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Normalize each document vector with the l1 norm (presumably so that its\n",
    "# entries sum to one -- confirm in utils).\n",
    "train.normalize(norm='l1')\n",
    "# The trailing ';' suppresses the display of the returned value.\n",
    "train.show_document(1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Test dataset: cleaned like the training set, then vectorized with the\n",
    "# training vocabulary so that both share the same features.\n",
    "test = utils.Text20News(data_home=FLAGS.dir_data,\n",
    "                        subset='test', remove=remove)\n",
    "test.clean_text(num='substitute')\n",
    "test.vectorize(vocabulary=train.vocab)\n",
    "test.data_info()\n",
    "\n",
    "# Drop short documents (counted over the selected vocabulary).\n",
    "wc = test.remove_short_documents(nwords=5, vocab='selected')\n",
    "print('shortest: {}, longest: {} words'.format(wc.min(), wc.max()))\n",
    "test.data_info(True)\n",
    "\n",
    "# Same normalization as the training set.\n",
    "test.normalize(norm='l1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Train/test split: the official 20NEWS split (subset='train' / 'test'), or a\n",
    "# seeded random split of the training set that holds out Ntest documents.\n",
    "use_official_split = True\n",
    "if use_official_split:\n",
    "    train_data = train.data.astype(np.float32)\n",
    "    test_data = test.data.astype(np.float32)\n",
    "    train_labels = train.labels\n",
    "    test_labels = test.labels\n",
    "else:\n",
    "    # Permute the rows of the training set (previously referenced the undefined\n",
    "    # name `dataset`, which raised NameError when this branch was enabled).\n",
    "    perm = np.random.RandomState(seed=42).permutation(train.data.shape[0])\n",
    "    Ntest = 6695\n",
    "    perm_test = perm[:Ntest]\n",
    "    perm_train = perm[Ntest:]\n",
    "    train_data = train.data[perm_train,:].astype(np.float32)\n",
    "    test_data = train.data[perm_test,:].astype(np.float32)\n",
    "    train_labels = train.labels[perm_train]\n",
    "    test_labels = train.labels[perm_test]\n",
    "\n",
    "# Features the graph is built from: the word embeddings, or else each word's\n",
    "# column of the (sparse) training data matrix.\n",
    "use_embeddings = True\n",
    "if use_embeddings:\n",
    "    graph_data = train.embeddings.astype(np.float32)\n",
    "else:\n",
    "    graph_data = train.data.T.astype(np.float32).toarray()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Build the feature graph from graph_data (FLAGS.number_edges neighbours per\n",
    "# vertex, FLAGS.metric similarity), coarsen it and compute the Laplacians.\n",
    "t_start = time.process_time()\n",
    "dist, idx = graph.distance_sklearn_metrics(graph_data, k=FLAGS.number_edges, metric=FLAGS.metric)\n",
    "A = graph.adjacency(dist, idx)\n",
    "# Actual number of edges vs. the lower bound number_edges * n_vertices / 2.\n",
    "print(\"{} > {} edges\".format(A.nnz//2, FLAGS.number_edges*graph_data.shape[0]//2))\n",
    "# Noise level 0: presumably leaves A unchanged; raise it to perturb the graph.\n",
    "A = graph.replace_random_edges(A, 0)\n",
    "graphs, perm = coarsening.coarsen(A, levels=FLAGS.coarsening_levels, self_connections=False)\n",
    "# Honour the normalized_laplacian flag instead of hard-coding normalized=True.\n",
    "L = [graph.laplacian(adj, normalized=FLAGS.normalized_laplacian) for adj in graphs]\n",
    "print('Execution time: {:.2f}s'.format(time.process_time() - t_start))\n",
    "#graph.plot_spectrum(L)\n",
    "#del graph_data, A, dist, idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Reorder the data with `perm` from the coarsening step (presumably so that\n",
    "# features line up with the coarsened graph's vertices). perm_data works on\n",
    "# dense arrays; the results are converted back to CSR matrices.\n",
    "t_start = time.process_time()\n",
    "train_data, test_data = [\n",
    "    scipy.sparse.csr_matrix(coarsening.perm_data(mat.toarray(), perm))\n",
    "    for mat in (train_data, test_data)]\n",
    "print('Execution time: {:.2f}s'.format(time.process_time() - t_start))\n",
    "del perm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# The training set is already shuffled, so no extra permutation is applied.\n",
    "\n",
    "# Validation set: the first FLAGS.val_size training samples (disabled), or\n",
    "# the test set itself (current choice).\n",
    "# NOTE(review): validating on the test set means any tuning against the\n",
    "# validation scores leaks test information.\n",
    "if False:\n",
    "    val_data, train_data = train_data[:FLAGS.val_size,:], train_data[FLAGS.val_size:,:]\n",
    "    val_labels, train_labels = train_labels[:FLAGS.val_size], train_labels[FLAGS.val_size:]\n",
    "else:\n",
    "    val_data, val_labels = test_data, test_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Reference scores from utils.baseline on the same train/test data, for\n",
    "# comparison with the models below (set to False to skip).\n",
    "if True:\n",
    "    utils.baseline(train_data, train_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Hyper-parameters shared by every model below (each cell copies and extends them).\n",
    "common = {'dir_name': '20news/', 'num_epochs': 80, 'batch_size': 100}\n",
    "# decay_steps: number of batches per epoch (a float under Python 3 division).\n",
    "common['decay_steps'] = len(train_labels) / common['batch_size']\n",
    "common['eval_frequency'] = 5 * common['num_epochs']\n",
    "common.update(filter='chebyshev5', brelu='b1relu', pool='mpool1')\n",
    "# Assumes integer labels 0 .. C-1.\n",
    "C = max(train_labels) + 1  # number of classes\n",
    "\n",
    "model_perf = utils.model_perf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    # Softmax classifier only: no graph-conv layers (F, K, p empty) and a single\n",
    "    # output layer of C units (presumed meaning of M -- see models.cgcnn).\n",
    "    # NOTE(review): learning_rate=1e3 is unusually large; presumably intended\n",
    "    # given the l1-normalized inputs -- confirm.\n",
    "    params.update(\n",
    "        regularization=0,\n",
    "        dropout=1,\n",
    "        learning_rate=1e3,\n",
    "        decay_rate=0.95,\n",
    "        momentum=0.9,\n",
    "        F=[],\n",
    "        K=[],\n",
    "        p=[],\n",
    "        M=[C],\n",
    "    )\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'fc_softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    # No graph-conv layers; presumably one hidden fully connected layer of\n",
    "    # 2500 units followed by the C-way output layer.\n",
    "    params.update(\n",
    "        regularization=0,\n",
    "        dropout=1,\n",
    "        learning_rate=0.1,\n",
    "        decay_rate=0.95,\n",
    "        momentum=0.9,\n",
    "        F=[],\n",
    "        K=[],\n",
    "        p=[],\n",
    "        M=[2500, C],\n",
    "    )\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'fc_fc_softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    # No graph-conv layers; presumably two hidden fully connected layers\n",
    "    # (2500 and 500 units) followed by the C-way output layer.\n",
    "    params.update(\n",
    "        regularization=0,\n",
    "        dropout=1,\n",
    "        learning_rate=0.1,\n",
    "        decay_rate=0.95,\n",
    "        momentum=0.9,\n",
    "        F=[],\n",
    "        K=[],\n",
    "        p=[],\n",
    "        M=[2500, 500, C],\n",
    "    )\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'fgconv_softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    # Fourier-domain graph convolution with 32 filters whose support K equals\n",
    "    # the number of vertices of L[0], then the C-way output layer.\n",
    "    params.update(\n",
    "        filter='fourier',\n",
    "        regularization=0,\n",
    "        dropout=1,\n",
    "        learning_rate=0.001,\n",
    "        decay_rate=1,\n",
    "        momentum=0,\n",
    "        F=[32],\n",
    "        K=[L[0].shape[0]],\n",
    "        p=[1],\n",
    "        M=[C],\n",
    "    )\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'sgconv_softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    # Spline graph filters (32 filters, K=5), then the C-way output layer.\n",
    "    params.update(\n",
    "        filter='spline',\n",
    "        regularization=1e-3,\n",
    "        dropout=1,\n",
    "        learning_rate=0.1,\n",
    "        decay_rate=0.999,\n",
    "        momentum=0,\n",
    "        F=[32],\n",
    "        K=[5],\n",
    "        p=[1],\n",
    "        M=[C],\n",
    "    )\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'cgconv_softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    # Filter type inherited from `common` ('chebyshev5'): 32 filters, K=5,\n",
    "    # then the C-way output layer.\n",
    "    params.update(\n",
    "        regularization=1e-3,\n",
    "        dropout=1,\n",
    "        learning_rate=0.1,\n",
    "        decay_rate=0.999,\n",
    "        momentum=0,\n",
    "        F=[32],\n",
    "        K=[5],\n",
    "        p=[1],\n",
    "        M=[C],\n",
    "    )\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "if True:\n",
    "    name = 'cgconv_fc_softmax'\n",
    "    params = common.copy()\n",
    "    params['dir_name'] += name\n",
    "    # Filter type inherited from `common` ('chebyshev5'): 5 filters, K=15,\n",
    "    # presumably followed by a 100-unit fully connected layer and the output layer.\n",
    "    params.update(\n",
    "        regularization=0,\n",
    "        dropout=1,\n",
    "        learning_rate=0.1,\n",
    "        decay_rate=0.999,\n",
    "        momentum=0,\n",
    "        F=[5],\n",
    "        K=[15],\n",
    "        p=[1],\n",
    "        M=[100, C],\n",
    "    )\n",
    "    model_perf.test(models.cgcnn(L, **params), name, params,\n",
    "                    train_data, train_labels, val_data, val_labels, test_data, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Summarize the results recorded by model_perf.test for each model above.\n",
    "model_perf.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "if False:\n",
    "    # Grid search around `params` from the most recently run model cell;\n",
    "    # fill grid_params with the hyper-parameter values to explore.\n",
    "    grid_params = {}\n",
    "    data = (train_data, train_labels, val_data, val_labels, test_data, test_labels)\n",
    "    utils.grid_search(params, grid_params, *data,\n",
    "                      model=lambda p: models.cgcnn(L, **p))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
